import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms

class HookModule:
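    """Capture a module's output feature map via a forward hook and expose
    gradients of a scalar objective with respect to that feature map."""
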
    def __init__(self, model, module):
        self.model = model
        self.outputs = None  # set by the forward hook on each forward pass
        self.handle = module.register_forward_hook(self._get_output)
        
    def _get_output(self, module, inputs, outputs):
        self.outputs = outputs
    
    def grads(self, outputs, retain_graph=True, create_graph=True):
        # Gradient of the scalar objective w.r.t. the hooked feature map;
        # create_graph=True keeps it differentiable so it can feed into a loss.
        grads = torch.autograd.grad(outputs=outputs, inputs=self.outputs,
                                    retain_graph=retain_graph, create_graph=create_graph)
        self.model.zero_grad()
        return grads[0]
    
    def remove(self):
        self.handle.remove()


class GradConstraint:
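    """Gradient-attribution constraints on hooked modules.

    Builds losses that steer where gradients flow: spatially (inside a
    foreground mask) and channel-wise (through per-class channel masks loaded
    from `channel_paths`).
    """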

    def __init__(self, model, modules, channel_paths, flag_3d=False, relu=False):
        self.model = model
        self.modules = modules
        self.hook_modules = []
        self.channels = []
        self.flag_3d = flag_3d
        self.relu = relu

        # One 0/1 channel mask per hooked module: [num_classes, num_channels]
        for channel_path in channel_paths:
            self.channels.append(torch.from_numpy(np.load(channel_path)).cuda())

    def add_hook(self):
        for module in self.modules:
            self.hook_modules.append(HookModule(model=self.model, module=module))

    def remove_hook(self):
        for hook_module in self.hook_modules:
            hook_module.remove()
        self.hook_modules.clear()

    def loss_spatial(self, outputs, labels, masks):
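        """Background-suppression loss: resize `masks` to each hooked feature
        map's spatial size and penalize gradient magnitude falling outside the
        foreground (mask == 1) region."""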
        nll_loss = torch.nn.NLLLoss()(outputs, labels)
        loss = 0
        for hook_module in self.hook_modules:
            grads = hook_module.grads(outputs=-nll_loss)
            if self.flag_3d:
                masks_fg = F.interpolate(masks, (grads.shape[2], grads.shape[3], grads.shape[4]), mode='trilinear')
            else:
                masks_fg = transforms.Resize((grads.shape[2], grads.shape[3]))(masks)
            masks_bg = 1 - masks_fg  # background = complement of the foreground mask
            if self.relu:
                grads_bg = F.relu(masks_bg * grads)
            else:
                grads_bg = torch.abs(masks_bg * grads)
            loss += grads_bg.sum()
        return loss

    def loss_channel(self, outputs, labels):
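        """Channel-constraint loss: penalize gradients of the top competing
        class on its own marked channels (high-response term), and gradients of
        the true class on its unmarked channels (low-response term)."""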
        # high response channel loss
        # Class indices sorted by descending score
        probs = torch.argsort(-outputs, dim=1)
        labels_ = []
        for i in range(labels.size(0)):
            if probs[i][0] == labels[i]:
                labels_.append(probs[i][1])  # correctly classified: take the runner-up class
            else:
                labels_.append(probs[i][0])  # misclassified: take the top-ranked (wrong) class
        labels_ = torch.tensor(labels_).cuda()
        nll_loss_ = torch.nn.NLLLoss()(outputs, labels_)
        # low response channel loss
        nll_loss = torch.nn.NLLLoss()(outputs, labels)

        loss = 0
        for i, hook_module in enumerate(self.hook_modules):
            # high response channel loss
            loss += self._loss_channel(channels=self.channels[i],
                                       grads=hook_module.grads(outputs=-nll_loss_),
                                       labels=labels_,
                                       is_high=True)

            # low response channel loss
            loss += self._loss_channel(channels=self.channels[i],
                                       grads=hook_module.grads(outputs=-nll_loss),
                                       labels=labels,
                                       is_high=False)
            break  # only constrain the first hooked module
        return loss

    def _loss_channel(self, channels, grads, labels, is_high=True):
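        """Weight per-channel gradient magnitudes by the class's channel mask
        (is_high=True) or by its complement, averaged over the batch."""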
        if self.relu:
            grads = F.relu(grads) 
        else:
            grads = torch.abs(grads)
        # Per-channel gradient magnitude: [batch_size, num_channels]
        spatial_dims = (2, 3, 4) if self.flag_3d else (2, 3)
        channel_grads = torch.sum(grads, dim=spatial_dims)

        loss = 0
        if is_high:
            # Penalize gradient response on the channels marked for class l
            for b, l in enumerate(labels):
                loss += (channel_grads[b] * channels[l]).sum()
        else:
            # Penalize gradient response on the channels not marked for class l
            for b, l in enumerate(labels):
                loss += (channel_grads[b] * (1 - channels[l])).sum()
        loss = loss / labels.size(0)
        return loss

    def loss_var(self, labels, base_label=1, loss_label=2):
        # Intended to increase the diversity of intramural-hematoma features by
        # aligning the variance of their key feature-vector elements with that
        # of aortic dissection; it did not work well and was abandoned.
        loss = torch.tensor(0., dtype=torch.float32, device=labels.device)
        for i, hook_module in enumerate(self.hook_modules):
            spatial_dims = (2, 3, 4) if self.flag_3d else (2, 3)
            # Per-sample mean feature vectors; the base class is detached so
            # only the loss class's variance is pulled toward it.
            base_feas = hook_module.outputs[labels == base_label].detach().mean(dim=spatial_dims)
            loss_feas = hook_module.outputs[labels == loss_label].mean(dim=spatial_dims)
            if base_feas.shape[0] == 0 or loss_feas.shape[0] == 0:
                break
            # Batch variance restricted to each class's marked channels
            base_var = torch.var(base_feas[:, self.channels[i][base_label].bool()], dim=0, unbiased=False).mean()
            loss_var = torch.var(loss_feas[:, self.channels[i][loss_label].bool()], dim=0, unbiased=False).mean()
            loss += F.relu(base_var - loss_var)
            break  # only process the first hooked module
        return loss


class GradIntegral:
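    """Inject Gaussian noise into the hooked modules' feature maps during the
    forward pass via forward hooks."""
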
    def __init__(self, model, modules):
        self.model = model  # kept for API symmetry with GradConstraint; unused
        self.modules = modules
        self.hooks = []

    def add_noise(self):
        for module in self.modules:
            hook = module.register_forward_hook(self._modify_feature_map)
            self.hooks.append(hook)

    def remove_noise(self):
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()

    # keep forward going after modifying the feature map
    @staticmethod
    def _modify_feature_map(module, inputs, outputs):
        # Return a new tensor instead of mutating `outputs` in place, which can
        # break autograd for some layers.
        noise = torch.randn_like(outputs)
        # noise = torch.normal(mean=0, std=3, size=outputs.shape).to(outputs.device)
        return outputs + noise

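
# --- Minimal usage sketch (not part of the original module) ---
# Assumptions: CUDA is available (the classes hard-code .cuda()); the hooked
# module is torchvision resnet18's layer4 (512 channels); "channels_demo.npy"
# is a hypothetical [num_classes, num_channels] 0/1 mask array generated here
# purely for illustration; log_softmax is applied because the losses use NLLLoss.
if __name__ == '__main__':
    from torchvision.models import resnet18

    num_classes, num_channels = 10, 512
    np.save('channels_demo.npy',
            np.random.randint(0, 2, (num_classes, num_channels)).astype(np.float32))

    model = resnet18(num_classes=num_classes).cuda()
    constraint = GradConstraint(model=model, modules=[model.layer4],
                                channel_paths=['channels_demo.npy'])
    constraint.add_hook()

    images = torch.randn(4, 3, 224, 224).cuda()
    labels = torch.randint(0, num_classes, (4,)).cuda()
    outputs = F.log_softmax(model(images), dim=1)

    loss = torch.nn.NLLLoss()(outputs, labels) + 0.1 * constraint.loss_channel(outputs, labels)
    loss.backward()
    constraint.remove_hook()

    # GradIntegral: perturb layer4's feature maps with noise during a forward pass
    gi = GradIntegral(model=model, modules=[model.layer4])
    gi.add_noise()
    _ = model(images)
    gi.remove_noise()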